Graph Safe Current Scaling Support for GroupedLinear Module/Ops + Fix CUBLAS GGEMM heuristics#3143
vthumbe1503 wants to merge 18 commits into
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Removed details about FP8 current scaling methods. Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
/te-ci pytorch
Greptile Summary
This PR adds graph-safe FP8 per-tensor current scaling support to
Confidence Score: 5/5
Safe to merge; the correctness fixes are well-targeted and the cuBLAS heuristic changes affect only algorithm selection, not GEMM output.
The two substantive correctness fixes are straightforward and verified by the new test parametrizations. The cuBLAS avg_m/avg_n/avg_k changes are heuristics fed to the kernel selector and do not affect GEMM output correctness.
transformer_engine/common/gemm/cublaslt_grouped_gemm.cu: specifically the nvte_grouped_gemm_with_discrete_out heuristic derivation, which is used for weight-gradient GEMMs and may now produce less accurate M/N/K estimates than the old code did for that call pattern.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[GroupedLinear Forward Call] --> B{_is_graph_safe_path_supported?}
B -- CC less than 9.0 --> C[Legacy split_quantize path]
B -- CC 9.0 to 11.0 --> D{with_quantized_compute?}
D -- No --> E{dtype BF16/FP16?}
E -- Yes --> F[Grouped Tensor Path BF16/FP16]
E -- No --> C
D -- Yes --> G{All Float8CurrentScalingQuantizer?}
G -- Yes --> H[Grouped Tensor Path FP8 Current Scaling NEW]
G -- No --> I{CC 10.0 to 11.0?}
I -- No --> C
I -- Yes --> J{All MXFP8Quantizer?}
J -- Yes --> K[Grouped Tensor Path MXFP8]
J -- No --> L{All NVFP4+RHT AND NOT single_grouped_weight?}
L -- Yes --> M[Grouped Tensor Path NVFP4]
L -- No --> C
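The decision tree in the flowchart above can be sketched in plain Python. This is an illustrative model only, not the actual `_is_graph_safe_path_supported` implementation: the `GroupedLinearConfig` class, the string labels for quantizer kinds and paths, and the `select_path` function are all hypothetical names introduced here. It also assumes that the "9.0 to 11.0" and "10.0 to 11.0" ranges exclude 11.0 and that anything outside them falls back to the legacy path, which the flowchart does not state.

```python
from dataclasses import dataclass
from typing import Sequence

LEGACY = "legacy_split_quantize"


@dataclass(frozen=True)
class GroupedLinearConfig:
    """Hypothetical container for the inputs the flowchart branches on."""
    compute_capability: float            # e.g. 9.0 or 10.0
    with_quantized_compute: bool
    dtype: str                           # "bf16", "fp16", "fp32", ...
    quantizer_kinds: Sequence[str] = ()  # one illustrative label per quantizer
    single_grouped_weight: bool = False


def _all_kind(kinds: Sequence[str], label: str) -> bool:
    """True only if there is at least one quantizer and all share the label."""
    return len(kinds) > 0 and all(kind == label for kind in kinds)


def select_path(cfg: GroupedLinearConfig) -> str:
    """Walk the flowchart top to bottom and return the chosen path label."""
    cc = cfg.compute_capability
    # Assumption: the supported range is 9.0 <= CC < 11.0.
    if not (9.0 <= cc < 11.0):
        return LEGACY
    if not cfg.with_quantized_compute:
        # High-precision compute: only BF16/FP16 use the grouped tensor path.
        return "grouped_bf16_fp16" if cfg.dtype in ("bf16", "fp16") else LEGACY
    if _all_kind(cfg.quantizer_kinds, "fp8_current_scaling"):
        # The branch marked NEW in the flowchart.
        return "grouped_fp8_current_scaling"
    if cc < 10.0:
        # MXFP8 and NVFP4 branches are only reached for CC 10.0 and above.
        return LEGACY
    if _all_kind(cfg.quantizer_kinds, "mxfp8"):
        return "grouped_mxfp8"
    if _all_kind(cfg.quantizer_kinds, "nvfp4_rht") and not cfg.single_grouped_weight:
        return "grouped_nvfp4"
    return LEGACY
```

For example, `select_path(GroupedLinearConfig(9.0, True, "bf16", ("fp8_current_scaling",) * 4))` follows the new FP8 current scaling branch, while the same quantizers on a CC 8.0 device fall back to the legacy split_quantize path.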
Reviews (10): Last reviewed commit: "fix m and n"
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
… weight being cuda graphable Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…3/TransformerEngine into nvfp4_and_fp8_current_scaling
/te-ci pytorch
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
/te-ci pytorch
denera left a comment
LGTM except for two minor fixes/clarifications in the GroupedMLP tests.
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci pytorch
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…3/TransformerEngine into nvfp4_and_fp8_current_scaling
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
/te-ci pytorch
timmoon10 left a comment
LGTM, pending CI and perf checks.
/te-ci